package org.geekbang.consistenthash;

import java.util.*;

/**
 * 带虚拟节点的一致性hash算法路由器实现
 * 虚拟节点命名规则：IP#序列号
 */
public class VirtualNodeConsistentHashServerRouter implements ServerRouter
{
    /**
     * 服务器列表
     */
    private List<String> servers = new LinkedList<>();

    /**
     * 每个服务器对应多少个虚拟节点
     */
    private Integer virtualNodesPerServer = 0;

    /**
     * 用TreeMap表示环
     */
    private SortedMap<Integer, String> ring = new TreeMap<>();

    /**
     * 统计每个服务器节点路由到的key数量
     */
    private Map<String, Integer> serverStatisic = new HashMap<>();

    public VirtualNodeConsistentHashServerRouter(Integer virtualNodesPerServer)
    {
        this.virtualNodesPerServer = virtualNodesPerServer;
    }

    /**
     * 将服务器节点和虚拟节点放入环中
     *
     * @param server
     */
    private void addServerNodeToRing(String server)
    {
        //添加服务器节点到环
        ring.put(getHash(server), server);
        //添加服务器虚拟节点到环
        for (int i = 0; i < virtualNodesPerServer; i++)
        {
            String virtualServer = server + "#" + i;
            ring.put(getHash(virtualServer), virtualServer);
        }
    }

    /**
     * 将服务器节点和虚拟节点从环中移除
     *
     * @param server
     */
    private void removeServerNodeFromRing(String server)
    {
        //将服务器节点从环中移除
        ring.remove(getHash(server));
        //将服务器虚接点从环中移除
        for (int i = 0; i < virtualNodesPerServer; i++)
        {
            String virtualServer = server + "#" + i;
            ring.remove(getHash(virtualServer));
        }
    }

    /**
     * 服务器路由到的key统计数加1
     *
     * @param server
     */
    private void increCount(String server)
    {
        Integer count = serverStatisic.get(server);
        serverStatisic.put(server, ++count);
    }

    @Override
    public void addServer(String server)
    {
        servers.add(server);
        this.addServerNodeToRing(server);
        serverStatisic.put(server, 0);
    }

    @Override
    public void addServers(Collection<String> servers)
    {
        for (String server : servers)
        {
            addServer(server);
        }
        //printRing();
    }

    /**
     * 打印换环
     */
    private void printRing()
    {
        System.out.println("ring table:");
        for (Map.Entry<Integer, String> entry : ring.entrySet())
        {
            System.out.println(entry.getKey() + "->" + entry.getValue() + "->" + getServer(entry.getValue()));
        }
    }

    @Override
    public void removeServer(String server)
    {
        servers.remove(server);
        this.removeServerNodeFromRing(server);
    }

    @Override
    public List<String> getServers()
    {
        return servers;
    }

    @Override
    public String getServer(String key)
    {
        String node = getNode(key);
        String server = null;
        //根据节点名称转换成对应服务器名称
        if (!node.contains("#"))
        {
            //不带#的为服务器真实节点
            server = node;
        } else
        {
            //带#的为服务器虚拟节点
            String[] strs = node.split("#");
            server = strs[0];
        }
        //路由到该服务器的key统计数量+1
        this.increCount(server);
        return server;
    }

    /**
     * 根据key得到节点（可能是服务器节点，也可能是服务器虚拟节点）
     *
     * @param key
     * @return
     */
    private String getNode(String key)
    {
        SortedMap<Integer, String> tailMap = ring.tailMap(getHash(key));
        if (tailMap.isEmpty())
        {
            return ring.get(ring.firstKey());
        } else
        {
            return ring.get(tailMap.firstKey());
        }
    }

    @Override
    public int getStandardDeviation()
    {
        double sum = 0;
        for (double value : serverStatisic.values())
        {
            sum += value;
        }
        //计算平均值
        double avg = sum / serverStatisic.values().size();
        double sumOfSquares = 0;
        for (double value : serverStatisic.values())
        {
            //System.out.println("value:" + value);
            sumOfSquares += Math.pow(value - avg, 2);
        }
        double standardDeviation = Math.sqrt(sumOfSquares / serverStatisic.values().size());
        return (int) standardDeviation;
    }

    /**
     * 重新实现Hash值(String的默认hash算法太连续)
     *
     * @param str
     * @return
     */
    private static int getHash(String str)
    {
        final int p = 16777619;
        int hash = (int) 2166136261L;
        for (int i = 0; i < str.length(); i++)
        {
            hash = (hash ^ str.charAt(i)) * p;
        }
        hash += hash << 13;
        hash ^= hash >> 7;
        hash += hash << 3;
        hash ^= hash >> 17;
        hash += hash << 5;

        // 如果算出来的值为负数则取其绝对值
        if (hash < 0)
        {
            hash = Math.abs(hash);
        }
        return hash;
    }
}